{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard",
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Train a language model from scratch\n",
        "\n",
        "txtai has a robust training pipeline that can fine-tune large language models (LLMs) for downstream tasks such as labeling text. txtai also has the ability to train language models from scratch.\n",
        "\n",
        "The vast majority of the time, fine-tuning an LLM yields the best results. But when making significant changes to the structure of a model, training from scratch is often required.\n",
        "\n",
        "Examples of significant changes are:\n",
        "\n",
        "- Changing the vocabulary size\n",
        "- Changing the number of hidden dimensions\n",
        "- Changing the number of attention heads or layers\n",
        "\n",
        "This notebook will show how to build a new tokenizer and train a small language model (known as a micromodel) from scratch.\n"
      ],
      "metadata": {
        "id": "-xU9P9iSR-Cy"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ],
      "metadata": {
        "id": "shlUi2kKS7KT"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xEvX9vCpn4E0"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline-train] datasets sentence-transformers onnxruntime onnx"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load dataset\n",
        "\n",
        "This example will use the `ag_news` dataset, which is a collection of news article headlines."
      ],
      "metadata": {
        "id": "408IyXzKFSiG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"ag_news\", split=\"train\")"
      ],
      "metadata": {
        "id": "IQ_ns6YvFRm1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train the tokenizer\n",
        "\n",
        "The first step is to train the tokenizer. We could use an existing tokenizer but in this case, we want a smaller vocabulary.\n"
      ],
      "metadata": {
        "id": "-vNVSA2FQnKj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "def stream(batch=10000):\n",
        "    # Yield the dataset text in batches\n",
        "    for x in range(0, len(dataset), batch):\n",
        "        yield dataset[x: x + batch][\"text\"]\n",
        "\n",
        "# Start from the bert-base-uncased tokenizer settings and train a new 500 token vocabulary\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
        "tokenizer = tokenizer.train_new_from_iterator(stream(), vocab_size=500, length=len(dataset))\n",
        "tokenizer.model_max_length = 512\n",
        "\n",
        "tokenizer.save_pretrained(\"bert\")"
      ],
      "metadata": {
        "id": "LJ2FskiiQ_l_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's test the tokenizer."
      ],
      "metadata": {
        "id": "BOW5JlxYS3Rm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(tokenizer.tokenize(\"Red Sox defeat Yankees 5-3\"))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "slLtmzfbRuf6",
        "outputId": "2391b58b-7428-49a2-9225-0268a1f2ad64"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['re', '##d', 'so', '##x', 'de', '##f', '##e', '##at', 'y', '##ank', '##e', '##es', '5', '-', '3']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "With a limited vocabulary size of 500, most words require multiple tokens. This limited vocabulary lowers the number of token representations the model needs to learn."
      ],
      "metadata": {
        "id": "IozRdXdEqegD"
      }
    },
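    {
      "cell_type": "markdown",
      "source": [
        "To make the splitting behavior concrete, here is a minimal sketch of greedy longest-match-first WordPiece splitting, the scheme BERT-style tokenizers apply to each word. The `wordpiece` helper and the tiny `vocab` set are hypothetical and for illustration only. The tokenizer trained above also handles normalization, punctuation and special tokens.\n",
        "\n",
        "```python\n",
        "def wordpiece(word, vocab):\n",
        "    # Greedy longest-match-first split of a single lowercased word\n",
        "    tokens, start = [], 0\n",
        "    while start < len(word):\n",
        "        end = len(word)\n",
        "        while end > start:\n",
        "            # Pieces after the first carry the ## continuation prefix\n",
        "            piece = word[start:end] if start == 0 else \"##\" + word[start:end]\n",
        "            if piece in vocab:\n",
        "                break\n",
        "            end -= 1\n",
        "        else:\n",
        "            # No piece matched, so the whole word is unknown\n",
        "            return [\"[UNK]\"]\n",
        "        tokens.append(piece)\n",
        "        start = end\n",
        "    return tokens\n",
        "\n",
        "# Hypothetical toy vocabulary, pieces taken from the output above\n",
        "vocab = {\"re\", \"##d\", \"so\", \"##x\", \"de\", \"##f\", \"##e\", \"##at\"}\n",
        "\n",
        "for word in [\"red\", \"sox\", \"defeat\"]:\n",
        "    print(word, wordpiece(word, vocab))\n",
        "```\n",
        "\n",
        "With this toy vocabulary, `defeat` has no single matching piece, so it falls back to `de`, `##f`, `##e` and `##at`, the same split shown in the tokenizer output above."
      ],
      "metadata": {}
    },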
    {
      "cell_type": "markdown",
      "source": [
        "# Train the language model\n",
        "\n",
        "Now it's time to train the model. We'll train a micromodel, which is an extremely small language model with a limited vocabulary. Micromodels, when paired with a limited vocabulary, have the potential to work in limited compute environments like edge devices and microcontrollers."
      ],
      "metadata": {
        "id": "gqEBeBEoTrup"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer, BertConfig, BertForMaskedLM\n",
        "\n",
        "from txtai.pipeline import HFTrainer\n",
        "\n",
        "config = BertConfig(\n",
        "    vocab_size = 500,\n",
        "    hidden_size = 50,\n",
        "    num_hidden_layers = 2,\n",
        "    num_attention_heads = 2,\n",
        "    intermediate_size = 100,\n",
        ")\n",
        "\n",
        "model = BertForMaskedLM(config)\n",
        "model.save_pretrained(\"bert\")\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bert\")\n",
        "\n",
        "train = HFTrainer()\n",
        "\n",
        "# Train model\n",
        "train((model, tokenizer), dataset, task=\"language-modeling\", output_dir=\"bert\",\n",
        "      fp16=True, per_device_train_batch_size=128, num_train_epochs=10,\n",
        "      dataloader_num_workers=2)"
      ],
      "metadata": {
        "id": "MEpZAr0TUMCK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Sentence embeddings\n",
        "\n",
        "Next let's take the language model and fine-tune it to build sentence embeddings. "
      ],
      "metadata": {
        "id": "53bvB9c6MbPS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "!wget https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/examples/training/nli/training_nli_v2.py\n",
        "!python training_nli_v2.py bert\n",
        "!mv output/* bert-nli"
      ],
      "metadata": {
        "id": "f11f5tjfS85m"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Embeddings search\n",
        "\n",
        "Now we'll build a txtai embeddings index using the fine-tuned model. We'll index the `ag_news` dataset. "
      ],
      "metadata": {
        "id": "FTOm5ofaMmcv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from txtai.embeddings import Embeddings\n",
        "\n",
        "# Get list of all text\n",
        "texts = dataset[\"text\"]\n",
        "\n",
        "embeddings = Embeddings({\"path\": \"bert-nli\", \"content\": True})\n",
        "embeddings.index((x, text, None) for x, text in enumerate(texts))"
      ],
      "metadata": {
        "id": "_kKe5kRnVRhM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's run a search and see how much the model has learned."
      ],
      "metadata": {
        "id": "Rh9yA6ZJM47H"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "embeddings.search(\"Boston Red Sox Cardinals World Series\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XCRhpfLmV1-q",
        "outputId": "662f1b1d-fcf5-4383-fd8a-369081a77501"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'id': '76733',\n",
              "  'text': 'Red Sox sweep Cardinals to win World Series The Boston Red Sox ended their 86-year championship drought with a 3-0 win over the St. Louis Cardinals in Game Four of the World Series.',\n",
              "  'score': 0.8008379936218262},\n",
              " {'id': '71169',\n",
              "  'text': 'Red Sox lead 2-0 over Cardinals of World Series The host Boston Red Sox scored a 6-2 victory over the St. Louis Cardinals, helped by Curt Schilling #39;s pitching through pain and seeping blood, in World Series Game 2 on Sunday night.',\n",
              "  'score': 0.7896029353141785},\n",
              " {'id': '70100',\n",
              "  'text': 'Sports: Red Sox 9 Cardinals 7 after 7 innings BOSTON Boston has scored twice in the seventh inning to take an 9-to-7 lead over the St. Louis Cardinals in the World Series opener at Fenway Park.',\n",
              "  'score': 0.7735188603401184}]"
            ]
          },
          "metadata": {},
          "execution_count": 49
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Not too bad. It's far from perfect but we can tell that it has some knowledge! This model was trained for 5 minutes, so there is certainly room for improvement by training longer and/or with a larger dataset.\n",
        "\n",
        "The standard `bert-base-uncased` model has 110M parameters and is around 440MB. Let's see how many parameters this model has."
      ],
      "metadata": {
        "id": "M5Pk1spcM72L"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Show number of parameters\n",
        "parameters = sum(p.numel() for p in embeddings.model.model.parameters())\n",
        "print(f\"Number of parameters:\\t\\t{parameters:,}\")\n",
        "print(f\"% of bert-base-uncased\\t\\t{(parameters / 110000000) * 100:.2f}%\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RIEnDwuxeakq",
        "outputId": "131930f1-e6cd-4b23-b9dc-a62e670eb8e4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of parameters:\t\t94,450\n",
            "% of bert-base-uncased\t\t0.09%\n"
          ]
        }
      ]
    },
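    {
      "cell_type": "markdown",
      "source": [
        "The parameter count above can be reproduced by hand from the config values. The sketch below is plain arithmetic and makes two assumptions: the `BertConfig` defaults of 512 positions and 2 token types are in effect, and the count covers the base encoder plus its pooler layer, without the masked language modeling head. The `linear` helper is hypothetical.\n",
        "\n",
        "```python\n",
        "# Values from the BertConfig used above\n",
        "vocab, hidden, layers, intermediate = 500, 50, 2, 100\n",
        "\n",
        "# Assumed BertConfig defaults\n",
        "positions, token_types = 512, 2\n",
        "\n",
        "def linear(inputs, outputs):\n",
        "    # Weight matrix plus bias vector\n",
        "    return inputs * outputs + outputs\n",
        "\n",
        "# LayerNorm has a weight and a bias per hidden dimension\n",
        "layernorm = 2 * hidden\n",
        "\n",
        "# Word, position and token type embeddings plus LayerNorm\n",
        "embedding = (vocab + positions + token_types) * hidden + layernorm\n",
        "\n",
        "# Query, key, value and output projections, feed-forward network and two LayerNorms\n",
        "layer = (\n",
        "    4 * linear(hidden, hidden)\n",
        "    + linear(hidden, intermediate)\n",
        "    + linear(intermediate, hidden)\n",
        "    + 2 * layernorm\n",
        ")\n",
        "\n",
        "# Pooler dense layer\n",
        "pooler = linear(hidden, hidden)\n",
        "\n",
        "total = embedding + layers * layer + pooler\n",
        "print(f\"{total:,}\")\n",
        "```\n",
        "\n",
        "This prints `94,450`. The number of attention heads does not appear in the calculation because the heads split the hidden dimension rather than adding weights. More than half of the parameters sit in the embeddings, which is why the vocabulary size matters so much for a model this small."
      ],
      "metadata": {}
    },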
    {
      "cell_type": "code",
      "source": [
        "!ls -lh bert-nli/pytorch_model.bin"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ATg4aInQeRN-",
        "outputId": "a6181fbc-ee6c-426e-882e-1d6cbefa22a0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "-rw-r--r-- 1 root root 386K Jan 11 20:52 bert-nli/pytorch_model.bin\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This model is 386KB and has only 0.09% of the parameters of `bert-base-uncased`. With proper vocabulary selection, a small language model has potential."
      ],
      "metadata": {
        "id": "7GfsF8ziNFJa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Quantization\n",
        "\n",
        "If 386KB isn't small enough, we can quantize the model to get it down even further. "
      ],
      "metadata": {
        "id": "CcbJNidNwuXt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from txtai.pipeline import HFOnnx\n",
        "\n",
        "onnx = HFOnnx()\n",
        "onnx(\"bert-nli\", task=\"pooling\", output=\"bert-nli.onnx\", quantize=True)"
      ],
      "metadata": {
        "id": "IYZnex9kRcb0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "embeddings = Embeddings({\"path\": \"bert-nli.onnx\", \"tokenizer\": \"bert-nli\", \"content\": True})\n",
        "embeddings.index((x, text, None) for x, text in enumerate(texts))\n",
        "embeddings.search(\"Boston Red Sox Cardinals World Series\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QL_1UosIVkZ7",
        "outputId": "6b2bedb8-5cc6-44b6-855a-617a6a07478c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'id': '76733',\n",
              "  'text': 'Red Sox sweep Cardinals to win World Series The Boston Red Sox ended their 86-year championship drought with a 3-0 win over the St. Louis Cardinals in Game Four of the World Series.',\n",
              "  'score': 0.8008379936218262},\n",
              " {'id': '71169',\n",
              "  'text': 'Red Sox lead 2-0 over Cardinals of World Series The host Boston Red Sox scored a 6-2 victory over the St. Louis Cardinals, helped by Curt Schilling #39;s pitching through pain and seeping blood, in World Series Game 2 on Sunday night.',\n",
              "  'score': 0.7896029353141785},\n",
              " {'id': '70100',\n",
              "  'text': 'Sports: Red Sox 9 Cardinals 7 after 7 innings BOSTON Boston has scored twice in the seventh inning to take an 9-to-7 lead over the St. Louis Cardinals in the World Series opener at Fenway Park.',\n",
              "  'score': 0.7735188603401184}]"
            ]
          },
          "metadata": {},
          "execution_count": 50
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!ls -lh bert-nli.onnx"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oCa3RDeInCkN",
        "outputId": "89cf379a-09cf-4fbc-8568-4d66085450f3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "-rw-r--r-- 1 root root 187K Jan 11 20:53 bert-nli.onnx\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We're down to 187KB with a quantized model!\n"
      ],
      "metadata": {
        "id": "r95T1HZXVhnZ"
      }
    },
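    {
      "cell_type": "markdown",
      "source": [
        "As a rough intuition for where the savings come from, 8-bit quantization stores weights as small integers in place of 32-bit floats. The sketch below shows a symmetric int8 scheme on a hypothetical list of weights. It illustrates the general idea only and is not the exact method applied by `HFOnnx`.\n",
        "\n",
        "```python\n",
        "# Hypothetical float32 weights\n",
        "weights = [0.42, -1.27, 0.05, 0.88]\n",
        "\n",
        "# One scale maps the largest magnitude to the int8 limit of 127\n",
        "scale = max(abs(w) for w in weights) / 127\n",
        "\n",
        "# Store each weight as a small integer\n",
        "quantized = [round(w / scale) for w in weights]\n",
        "\n",
        "# Approximate the original values when the weights are used\n",
        "restored = [q * scale for q in quantized]\n",
        "\n",
        "print(quantized)\n",
        "print([round(r, 2) for r in restored])\n",
        "```\n",
        "\n",
        "Each quantized weight needs 1 byte in place of 4, at the cost of a small rounding error per weight."
      ],
      "metadata": {}
    },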
    {
      "cell_type": "markdown",
      "source": [
        "# Train on BERT dataset\n",
        "\n",
        "The [BERT paper](https://arxiv.org/abs/1810.04805) has all the information regarding training parameters and datasets used. Hugging Face Datasets hosts the `bookcorpus` and `wikipedia` datasets.\n",
        "\n",
        "Training on a dataset of this size is out of scope for this notebook, but the example code below shows how to build the BERT dataset.\n",
        "\n",
        "```python\n",
        "from datasets import concatenate_datasets, load_dataset\n",
        "\n",
        "bookcorpus = load_dataset(\"bookcorpus\", split=\"train\")\n",
        "wiki = load_dataset(\"wikipedia\", \"20220301.en\", split=\"train\")\n",
        "wiki = wiki.remove_columns([col for col in wiki.column_names if col != \"text\"])\n",
        "dataset = concatenate_datasets([bookcorpus, wiki])\n",
        "```\n",
        "\n",
        "Then the same steps to train the tokenizer and model can be run. The dataset is 25GB compressed, so it will take some space and time to process!"
      ],
      "metadata": {
        "id": "aPaZsoxnYW8I"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Wrapping up\n",
        "\n",
        "This notebook covered how to build micromodels from scratch with txtai. Micromodels can be fully rebuilt in hours using the most up-to-date knowledge available. If properly constructed, prepared and trained, micromodels have the potential to be a viable choice for resource-limited environments. They can also help when real-time response is more important than having the highest accuracy scores.\n",
        "\n",
        "It's our hope that further research and exploration into micromodels leads to productive and useful models."
      ],
      "metadata": {
        "id": "4L8smyyXc8q8"
      }
    }
  ]
}
